from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding

# ======================
# Helper: save content to a file, picking binary or text mode automatically
# ======================
def save_to_file(filename, content, mode=None):
    """Write *content* to *filename*, replacing any existing file.

    Args:
        filename: Path of the file to write.
        content: Bytes-like data (``bytes``, ``bytearray``, ``memoryview``),
            written in binary mode, or ``str``, written as text.
        mode: Explicit ``open()`` mode. When ``None``, ``'wb'`` is used for
            bytes-like content and ``'w'`` for anything else.

    Text modes are always opened as UTF-8 so the written file does not depend
    on the platform's locale encoding.
    """
    if mode is None:
        mode = "wb" if isinstance(content, (bytes, bytearray, memoryview)) else "w"
    # open() rejects a non-None encoding in binary mode, so only set one for text.
    encoding = None if "b" in mode else "utf-8"
    with open(filename, mode, encoding=encoding) as f:
        f.write(content)

# ======================
# Helper: format bytes as a C-style hex array body (16 values per line by default)
# ======================
def bytes_to_c_style_hex(byte_data, bytes_per_line=16):
    """Format *byte_data* as the body of a C array initializer.

    Each value becomes ``0xNN`` (upper-case hex). Values on a line are
    separated by ``", "``, and every line, including the last, ends with a
    trailing comma, so the output can be pasted directly between ``{`` and ``}``.

    Args:
        byte_data: A bytes-like object or any iterable of ints.
        bytes_per_line: Number of values per output line (default 16).

    Returns:
        The formatted lines joined by newlines; an empty string for empty input.

    Raises:
        ValueError: If *bytes_per_line* is less than 1.
    """
    if bytes_per_line < 1:
        raise ValueError("bytes_per_line must be >= 1")
    hex_strs = [f"0x{b:02X}" for b in byte_data]
    return "\n".join(
        ", ".join(hex_strs[start:start + bytes_per_line]) + ","
        for start in range(0, len(hex_strs), bytes_per_line)
    )

# ======================
# Private helpers for the main routine
# ======================
def _print_hex_section(title, data):
    """Print *title*, a C-style hex dump of *data*, and an 80-character separator."""
    print(title)
    print(bytes_to_c_style_hex(data))
    print("\n" + "=" * 80 + "\n")


def _format_components(components):
    """Render ``(name, int)`` pairs as ``NAME = HEX`` lines, each ending in a newline."""
    return "".join(f"{name} = {int_to_hex(value)}\n" for name, value in components)


# ======================
# Main routine
# ======================
def generate_rsa_keypair_and_sign():
    """Generate a 3072-bit RSA key pair, export it, then sign and verify a sample.

    Files written to the current working directory:
      - RSA_private_key_3072.der        private key, DER, PKCS#1 (TraditionalOpenSSL)
      - RSA_public_key_3072.der         public key, DER, PKCS#1
      - RSA_private_key_3072.pem        private key, PEM, PKCS#8 ("BEGIN PRIVATE KEY")
      - RSA_public_key_3072.pem         public key, PEM, PKCS#1 ("BEGIN RSA PUBLIC KEY")
      - RSA_private_key_components.txt  N, E, D, P, Q, DP, DQ, QP as upper-case hex
      - RSA_public_key_components.txt   N, E as upper-case hex

    Private-key material is written unencrypted (``NoEncryption``), and the
    DER private key is also printed to stdout, so this is only appropriate
    for test/development keys.

    A fixed message is signed with PKCS#1 v1.5 padding over SHA-256 and then
    verified with the public key. An invalid signature is reported on stdout;
    any other error propagates to the caller.
    """
    print("[*] 正在生成 RSA 密钥对（3072 位）...\n")

    # NOTE: newer `cryptography` releases ignore `backend`; it is kept so the
    # script still works with older releases that required it.
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=3072,
        backend=default_backend(),
    )
    public_key = private_key.public_key()

    # ===== 1. Private key, DER, PKCS#1 (OpenSSL "traditional" RSA format)
    der_private_key = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    _print_hex_section("[+] ✅ 私钥 (DER 格式，PKCS#1) 已生成", der_private_key)
    save_to_file("RSA_private_key_3072.der", der_private_key)

    # ===== 2. Public key, DER, PKCS#1
    der_public_key = public_key.public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.PKCS1,
    )
    _print_hex_section("[+] ✅ 公钥 (DER 格式，PKCS#1) 已生成", der_public_key)
    save_to_file("RSA_public_key_3072.der", der_public_key)

    # ===== 3. Private key, PEM, PKCS#8 ("BEGIN PRIVATE KEY") - standard modern format
    pem_private_key_p8 = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    print("[+] ✅ 私钥 PEM (PKCS#8 格式，即 BEGIN PRIVATE KEY) 已生成")
    save_to_file("RSA_private_key_3072.pem", pem_private_key_p8)

    # ===== 4. (Disabled) Public key, PEM, SubjectPublicKeyInfo ("BEGIN PUBLIC KEY"),
    #          the general-purpose public key format. Uncomment if a consumer needs it.
    # pem_public_key_p8 = public_key.public_bytes(
    #     encoding=serialization.Encoding.PEM,
    #     format=serialization.PublicFormat.SubjectPublicKeyInfo
    # )
    # print("[+] ✅ 公钥 PEM (PKCS#8 / SubjectPublicKeyInfo, 即 BEGIN PUBLIC KEY) 已生成")
    # save_to_file("RSA_public_key_3072_PKCS8.pem", pem_public_key_p8)

    # ===== 5. Public key, PEM, PKCS#1 ("BEGIN RSA PUBLIC KEY") - legacy
    #          RSA-specific format that some tools still expect
    pem_public_key_p1 = public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.PKCS1,
    )
    print("[+] ✅ 公钥 PEM (PKCS#1 格式，即 BEGIN RSA PUBLIC KEY) 已生成")
    save_to_file("RSA_public_key_3072.pem", pem_public_key_p1)

    # ===== 6. Raw key components as upper-case hex text files
    private_numbers = private_key.private_numbers()
    public_numbers = private_numbers.public_numbers
    public_components = [("N", public_numbers.n), ("E", public_numbers.e)]
    private_components = public_components + [
        ("D", private_numbers.d),
        ("P", private_numbers.p),
        ("Q", private_numbers.q),
        ("DP", private_numbers.dmp1),  # d mod (p - 1)
        ("DQ", private_numbers.dmq1),  # d mod (q - 1)
        ("QP", private_numbers.iqmp),  # q^-1 mod p (CRT coefficient)
    ]
    save_to_file("RSA_private_key_components.txt", _format_components(private_components))
    save_to_file("RSA_public_key_components.txt", _format_components(public_components))

    # ===== 7. Sign & verify (original logic). The bytes spell "Hello, World!".
    message = [0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x20, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21]
    original_data = bytes(message)

    print("[*] 使用私钥对数据进行签名...")
    signature = private_key.sign(original_data, padding.PKCS1v15(), hashes.SHA256())
    _print_hex_section("[+] ✅ 签名已生成", signature)

    print("[*] 使用公钥验签...")
    try:
        public_key.verify(signature, original_data, padding.PKCS1v15(), hashes.SHA256())
    except InvalidSignature as e:
        # Only a genuine signature mismatch is reported as a verification
        # failure; any other exception is a real error and propagates.
        print("[❌] 验签失败！")
        print(f"错误: {e}")
    else:
        print("[✅] 验签成功！签名有效。")

    print("\n[*] 所有操作完成。")

# ======================
# Helper: render an integer as upper-case hexadecimal
# ======================
def int_to_hex(n):
    """Return *n* as upper-case hex digits, without a ``0x`` prefix.

    >>> int_to_hex(65537)
    '10001'
    """
    return f"{n:X}"

# ======================
# Entry point
# ======================
if __name__ == "__main__":
    # Run the full demo only when executed as a script, not when imported.
    generate_rsa_keypair_and_sign()